Don't unnecessarily wrap the elem in PythonTensor #554


Merged: 4 commits into main, Mar 3, 2022

Conversation

@ezyang ezyang (Contributor) commented Mar 2, 2022

Instead of saying that a PythonTensor *has* a regular (e.g., CPU) tensor
and an FX proxy, a PythonTensor *is* a regular CPU tensor that also
carries an FX proxy (which updates as we go along).

Partially addresses #465, and it also fixes some expected failures in the test suite.
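To make the has-a vs. is-a distinction concrete, here is a minimal sketch (the class and attribute names are illustrative, not functorch's actual code), built on the same `torch.Tensor._make_subclass` call that appears in this PR's diff:

```python
import torch

# "Has a": a plain wrapper object holding the tensor and proxy as fields.
class WrappedTensor:
    def __init__(self, elem, proxy):
        self.elem = elem    # the regular (e.g., CPU) tensor
        self.proxy = proxy  # the FX proxy

# "Is a": the object is itself a torch.Tensor subclass that shares elem's
# storage and metadata, carrying the FX proxy as an extra attribute.
class ProxiedTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, elem, proxy):
        r = torch.Tensor._make_subclass(cls, elem, elem.requires_grad)
        r.proxy = proxy
        return r
```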

@ezyang ezyang requested review from Chillee and zou3519 and removed request for Chillee March 2, 2022 15:39
@zou3519 zou3519 requested a review from Chillee March 2, 2022 15:46
@Chillee Chillee (Contributor) commented Mar 2, 2022

I’ll review the rest in a bit, but sadly this doesn’t fix #465 - that’s also a problem for vmap, not just AOTAutograd.

@Chillee Chillee (Contributor) left a comment

More or less fine with this - I feel like this is just sidestepping the real problem though, which is that we have some issues with wrapper tensor subclasses.

# TODO: this might not actually work, I didn't test it when
# I changed device derivation to work off of the types of the
# input devices
args = pytree.tree_map(lambda x: torch.ones_like(x, device=x.device), args)
Contributor

This doesn't work, since at this point the args will be meta tensors, and the device will simply be the meta device.
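A quick standalone illustration of this point (a hypothetical snippet, not from the PR): once an input is a meta tensor, its `.device` carries no information about the real device.

```python
import torch

x = torch.empty(3, device="meta")        # stand-in for a traced argument
y = torch.ones_like(x, device=x.device)
print(y.device)  # meta -- the real CPU/CUDA device is no longer recoverable
```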

Contributor

I'm not super opposed to ripping out the meta tensor code entirely though, and reimplementing it some different way if you think there's a better way to do it :P

Contributor Author

Yeah, let's dump the meta tensor code for now, and I'll reimplement it shortly.

# PythonTensor boundary.
# assert not elem.requires_grad or not torch.is_grad_enabled()

r = torch.Tensor._make_subclass(cls, elem, elem.requires_grad)
Contributor

This doesn't work for meta tensors, since at this point elem will be a meta tensor. So we're just gonna make a PythonTensor with the meta device anyway.

That's why I went through all of the shenanigans of inferring the output device - if we run with meta tensors, then at no point do we have the actual output device of the operator. All we have is the device of the input tensors.

So... ripping out the device inference logic will make the meta-tracing stuff not work at all, in which case we should just remove all of it :P
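For illustration, a sketch of that failure mode (the class here is a hypothetical stand-in for PythonTensor, not the real code): when elem is a meta tensor, the subclass created by `_make_subclass` simply inherits the meta device.

```python
import torch

class PythonTensorLike(torch.Tensor):  # hypothetical stand-in
    @staticmethod
    def __new__(cls, elem):
        return torch.Tensor._make_subclass(cls, elem, elem.requires_grad)

t = PythonTensorLike(torch.empty(2, device="meta"))
print(t.device)  # meta -- the operator's real output device is lost
```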

Contributor Author

This seems like a bigger structural problem for meta tensors. Will need to think about this...

ezyang added 2 commits March 2, 2022 17:26
Signed-off-by: Edward Z. Yang <[email protected]>
Signed-off-by: Edward Z. Yang <[email protected]>
@ezyang ezyang merged commit e7444f9 into main Mar 3, 2022
zou3519 pushed a commit to zou3519/pytorch that referenced this pull request Jul 20, 2022
Don't unnecessarily wrap the elem in PythonTensor (pytorch/functorch#554)

* Don't unnecessarily wrap the elem in PythonTensor

Instead of saying that a PythonTensor has a regular (e.g., CPU) tensor
and an FX proxy, a PythonTensor *is a* regular CPU tensor, that also
carries an FX proxy (that updates as we go along).

This should fix pytorch/functorch#465 and
it also fixed some expected failures in the test suite.

This kills the meta variant logic entirely; maybe some other time we'll
try to bring it back.

Signed-off-by: Edward Z. Yang <[email protected]>
bigfootjon pushed a commit to pytorch/pytorch that referenced this pull request Jul 21, 2022
Don't unnecessarily wrap the elem in PythonTensor (pytorch/functorch#554)